{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-vLwpT31YOJk"
      },
      "source": [
        "TODO(b/138297412): This colab retains some useful code snippets and demonstrations that used to be in the tf.function/AutoGraph customization tutorial, and should be rolled into the existing docs as part of a broader markdown-\u003ecolab conversion."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "otIdN1TS8N7S"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "I0xDjO4SHLUD"
      },
      "source": [
        "Define a helper function to demonstrate the kinds of errors you might encounter:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "D25apou9IOXa"
      },
      "outputs": [],
      "source": [
        "import traceback\n",
        "import contextlib\n",
        "\n",
        "# Some helper code to demonstrate the kinds of errors you might encounter.\n",
        "@contextlib.contextmanager\n",
        "def assert_raises(error_class):\n",
        "  try:\n",
        "    yield\n",
        "  except error_class:\n",
        "    print('Caught expected exception \\n  {}:'.format(error_class))\n",
        "    traceback.print_exc(limit=2)\n",
        "  else:\n",
        "    raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
        "        error_class))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5f05Vr_YBUCz"
      },
      "source": [
        "## Using AutoGraph\n",
        "\n",
        "The [AutoGraph](https://www.tensorflow.org/guide/function) library is fully integrated with `tf.function`. It rewrites conditionals and loops that depend on Tensors so that they run dynamically in the graph.\n",
        "\n",
        "`tf.cond` and `tf.while_loop` continue to work with `tf.function`, but code with control flow is often easier to write and understand in imperative style."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xgKmkrNTZSyz"
      },
      "source": [
        "## AutoGraph: Conditionals\n",
        "\n",
        "AutoGraph will convert `if` statements into the equivalent `tf.cond` calls.\n",
        "\n",
        "This substitution is made if the condition is a Tensor. Otherwise, the conditional is executed during tracing."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "20WlM9T2I9EV"
      },
      "source": [
        "Here is a function that checks if the resulting graph uses `tf.cond`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "E-7KllizZYsy"
      },
      "outputs": [],
      "source": [
        "def test_tf_cond(f, *args):\n",
        "  # Trace `f` for these arguments and look for a `cond` node in the graph.\n",
        "  g = f.get_concrete_function(*args).graph\n",
        "  if any(node.name == 'cond' for node in g.as_graph_def().node):\n",
        "    print(\"{}({}) uses tf.cond.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n",
        "  else:\n",
        "    print(\"{}({}) executes normally.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n",
        "\n",
        "  print(\"  result: \", f(*args).numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DlqiutEEJHOe"
      },
      "source": [
        "Passing a Python `True` executes the conditional normally:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "fCMywOXwJLIQ"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def dropout(x, training=True):\n",
        "  if training:\n",
        "    x = tf.nn.dropout(x, rate=0.5)\n",
        "  return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "68D2RZ17JM8u"
      },
      "outputs": [],
      "source": [
        "test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WEz0QYucJPBa"
      },
      "source": [
        "But passing a tensor replaces the Python `if` with a `tf.cond`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "o86paGR-Zadi"
      },
      "outputs": [],
      "source": [
        "test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5xFLfdApZh8q"
      },
      "source": [
        "`tf.cond` has a number of subtleties.\n",
        "\n",
        "It works by tracing both branches of the conditional and then choosing the appropriate branch at runtime, depending on the condition. Because both branches are traced, Python code in both of them runs during tracing, which can be unexpected."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "VTMoZEVaZiwk"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def f(x):\n",
        "  if x \u003e 0:\n",
        "    x = x + 1.\n",
        "    print(\"Tracing `then` branch\")\n",
        "  else:\n",
        "    x = x - 1.\n",
        "    print(\"Tracing `else` branch\")\n",
        "  return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "HqBVIZWb0Qzn"
      },
      "outputs": [],
      "source": [
        "f(-1.0).numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "BIMfbXlW0QdP"
      },
      "outputs": [],
      "source": [
        "f(1.0).numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2nBnJ42v0Pvq"
      },
      "outputs": [],
      "source": [
        "f(tf.constant(1.0)).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zyzzvtN5Jfpb"
      },
      "source": [
        "`tf.cond` also requires that if one branch creates a tensor used downstream, the other branch must create that tensor too."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "k_dxWHeFZlaQ"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def f():\n",
        "  if tf.constant(True):\n",
        "    x = tf.ones([3, 3])\n",
        "  return x\n",
        "\n",
        "# Throws an error because both branches need to define `x`.\n",
        "with assert_raises(ValueError):\n",
        "  f()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wP-LZP6cztnu"
      },
      "source": [
        "If you want to be sure that a particular section of control flow is never converted by AutoGraph, explicitly convert the object to a Python type so that an error is raised instead:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "iG_VDavjzrzV"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def f(x, y):\n",
        "  if bool(x):\n",
        "    y = y + 1.\n",
        "    print(\"Tracing `then` branch\")\n",
        "  else:\n",
        "    y = y - 1.\n",
        "    print(\"Tracing `else` branch\")\n",
        "  return y"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kQ4CRP9T0rH2"
      },
      "outputs": [],
      "source": [
        "f(True, 0).numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ww9tCzHy0rkv"
      },
      "outputs": [],
      "source": [
        "f(False, 0).numpy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ppuV7iug0r7i"
      },
      "outputs": [],
      "source": [
        "with assert_raises(TypeError):\n",
        "  f(tf.constant(True), 0.0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yho4J0a0ZkQS"
      },
      "source": [
        "## AutoGraph and loops\n",
        "\n",
        "AutoGraph has a few simple rules for converting loops.\n",
        "\n",
        "- `for`: Convert if the iterable is a tensor\n",
        "- `while`: Convert if the while condition depends on a tensor\n",
        "\n",
        "If a loop is converted, it will be dynamically unrolled with `tf.while_loop`, or in the special case of a `for x in tf.data.Dataset`, transformed into `tf.data.Dataset.reduce`.\n",
        "\n",
        "If a loop is _not_ converted, it will be statically unrolled."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "OyzGNQAuZsky"
      },
      "outputs": [],
      "source": [
        "def test_dynamically_unrolled(f, *args):\n",
        "  g = f.get_concrete_function(*args).graph\n",
        "  if any(node.name == 'while' for node in g.as_graph_def().node):\n",
        "    print(\"{}({}) uses tf.while_loop.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n",
        "  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):\n",
        "    print(\"{}({}) uses tf.data.Dataset.reduce.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n",
        "  else:\n",
        "    print(\"{}({}) gets unrolled.\".format(\n",
        "        f.__name__, ', '.join(map(str, args))))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KFO1BSN9JkRP"
      },
      "source": [
        "### For loops\n",
        "\n",
        "Here is a `tf.function` that demonstrates static unrolling:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "frecgTco_00V"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def for_in_range():\n",
        "  x = 0\n",
        "  for i in range(5):\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(for_in_range)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PMdl0azc_5d4"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def for_in_tfrange():\n",
        "  x = tf.constant(0, dtype=tf.int32)\n",
        "  for i in tf.range(5):\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(for_in_tfrange)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Q7tmncQTZt6_"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def for_in_tfdataset():\n",
        "  x = tf.constant(0, dtype=tf.int64)\n",
        "  for i in tf.data.Dataset.range(5):\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(for_in_tfdataset)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "eyPzDYiJAC8f"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def while_py_cond():\n",
        "  x = 5\n",
        "  while x \u003e 0:\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(while_py_cond)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "l6s7aU-padY5"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def while_tf_cond():\n",
        "  x = tf.constant(5)\n",
        "  while x \u003e 0:\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(while_tf_cond)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dSr64Xn6ap-S"
      },
      "source": [
        "If you have a `break` or early `return` clause that depends on a tensor, the top-level condition or iterable should also be a tensor.\n",
        "\n",
        "Compare the following examples:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "hG2Fe_OEAwpY"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def while_py_true_py_break(x):\n",
        "  while True:  # py true\n",
        "    if x == 0: # py break\n",
        "      break\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(while_py_true_py_break, 5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Sr2cn5bY_E_9"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def buggy_while_py_true_tf_break(x):\n",
        "  while True:   # py true\n",
        "    if tf.equal(x, 0): # tf break\n",
        "      break\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "with assert_raises(TypeError):\n",
        "  test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Q-VirD-5avdZ"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def while_tf_true_tf_break(x):\n",
        "  while tf.constant(True): # tf true\n",
        "    if x == 0:  # tf break: x is a tensor inside the converted loop\n",
        "      break\n",
        "    x -= 1\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(while_tf_true_tf_break, 5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Upx5J0j8_Ldu"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def buggy_py_for_tf_break():\n",
        "  x = 0\n",
        "  for i in range(5):  # py for\n",
        "    if tf.equal(i, 3): # tf break\n",
        "      break\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "with assert_raises(TypeError):\n",
        "  test_dynamically_unrolled(buggy_py_for_tf_break)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "GQHbodav_QMt"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def tf_for_py_break():\n",
        "  x = 0\n",
        "  for i in tf.range(5): # tf for\n",
        "    if i == 3:  # py break\n",
        "      break\n",
        "    x += i\n",
        "  return x\n",
        "\n",
        "test_dynamically_unrolled(tf_for_py_break)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hyksHW9TCukR"
      },
      "source": [
        "In order to accumulate results from a dynamically unrolled loop, you'll want to use `tf.TensorArray`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "HJ3Vb3dXfefN"
      },
      "outputs": [],
      "source": [
        "batch_size = 2\n",
        "seq_len = 3\n",
        "feature_size = 4\n",
        "\n",
        "def rnn_step(inp, state):\n",
        "  return inp + state\n",
        "\n",
        "@tf.function\n",
        "def dynamic_rnn(rnn_step, input_data, initial_state):\n",
        "  # [batch, time, features] -\u003e [time, batch, features]\n",
        "  input_data = tf.transpose(input_data, [1, 0, 2])\n",
        "  max_seq_len = input_data.shape[0]\n",
        "\n",
        "  states = tf.TensorArray(tf.float32, size=max_seq_len)\n",
        "  state = initial_state\n",
        "  for i in tf.range(max_seq_len):\n",
        "    state = rnn_step(input_data[i], state)\n",
        "    states = states.write(i, state)\n",
        "  return tf.transpose(states.stack(), [1, 0, 2])\n",
        "  \n",
        "dynamic_rnn(rnn_step,\n",
        "            tf.random.uniform([batch_size, seq_len, feature_size]),\n",
        "            tf.zeros([batch_size, feature_size]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9gmLpHY-bkly"
      },
      "source": [
        "### Gotchas\n",
        "\n",
        "As with `tf.cond`, `tf.while_loop` also comes with a number of subtleties.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FJdfznhhKO7D"
      },
      "source": [
        "#### Zero iterations\n",
        "\n",
        "Since a loop can execute zero times, all tensors used downstream of the `tf.while_loop` must be initialized before the loop.\n",
        "\n",
        "Here is an example of incorrect code:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CocT5RHwblrQ"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def buggy_loop_var_uninitialized():\n",
        "  for i in tf.range(3):\n",
        "    x = i\n",
        "  return x\n",
        "\n",
        "with assert_raises(ValueError):\n",
        "  buggy_loop_var_uninitialized()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ncr7tRZ1KWh9"
      },
      "source": [
        "And the correct version:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Wm7wIKXcCDGf"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def f():\n",
        "  x = tf.constant(0)\n",
        "  for i in tf.range(3):\n",
        "    x = i\n",
        "  return x\n",
        "\n",
        "f()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "CM7qXVY0KZHB"
      },
      "source": [
        "#### Consistent shapes and types\n",
        "\n",
        "The shapes and dtypes of all loop variables must stay the same from one iteration to the next.\n",
        "\n",
        "Here is an incorrect example that attempts to change a tensor's dtype:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "FSftc9cCbpAo"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def buggy_loop_type_changes():\n",
        "  x = tf.constant(0, dtype=tf.float32)\n",
        "  for i in tf.range(3): # Yields tensors of type tf.int32...\n",
        "    x = i\n",
        "  return x\n",
        "\n",
        "with assert_raises(TypeError):\n",
        "  buggy_loop_type_changes()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "M5l90NAHKsUM"
      },
      "source": [
        "Here is an incorrect example that attempts to change a tensor's shape while iterating, followed by a version that keeps the shape fixed by padding:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "kWF189prbuK0"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def buggy_concat():\n",
        "  x = tf.ones([0, 10])\n",
        "  for i in tf.range(5):\n",
        "    x = tf.concat([x, tf.ones([1, 10])], axis=0)\n",
        "  return x\n",
        "\n",
        "with assert_raises(ValueError):\n",
        "  buggy_concat()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "miYnYcznCHeV"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def concat_with_padding():\n",
        "  x = tf.zeros([5, 10])\n",
        "  for i in tf.range(5):\n",
        "    x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)\n",
        "    x.set_shape([5, 10])\n",
        "  return x\n",
        "\n",
        "concat_with_padding()\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "performance.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
